package proxy

import (
	"errors"
	"gosh/config"
	"gosh/pkg/model"
	"io/ioutil"
	"os"
	"path"
	"regexp"
	"strings"
	"time"
)

// ProxyConfig is the package-level Config instance; it starts out as an empty
// Config.
var ProxyConfig = &Config{}

// Config holds the settings of one proxy (forwarding) session: the SSH
// credentials of the target host, the SSH credentials of an optional bastion,
// and the forwarding parameters. Parse fills in the fields that were left unset.
type Config struct {
	Host                 string            // target address; Parse stores it as Condition["connect_ip"]
	Port                 int               // target SSH port; Parse defaults it to the DB value or 22
	User                 string            // target SSH user; Parse defaults it to the DB value or $USER
	Password             string            // target SSH password; Parse may fill it from the DB record
	Key                  string            // target private key file; Parse defaults it to the DB value or $HOME/.ssh/id_rsa
	KeyContent           string            // contents of the Key file, loaded by Parse
	KeyPassphrase        string            // passphrase for Key; Parse may fill it from the DB record
	BastionHost          string            // bastion address; when empty Parse takes it from the DB bastion record, if any
	BastionPort          int               // bastion SSH port; Parse defaults it to the DB value or 22
	BastionUser          string            // bastion SSH user; Parse defaults it to the DB value or $USER
	BastionPassword      string            // bastion SSH password; Parse may fill it from the DB record
	BastionKey           string            // bastion private key file; Parse defaults it to the DB value or $HOME/.ssh/id_rsa
	BastionKeyContent    string            // contents of the BastionKey file, loaded by Parse
	BastionKeyPassphrase string            // passphrase for BastionKey; Parse may fill it from the DB record
	ConnectTimeout       time.Duration     // not read in this file; presumably the connection timeout, confirm against callers
	LocalAddress         string            // not read in this file; presumably the local forwarding endpoint, confirm against callers
	RemoteAddress        string            // not read in this file; presumably the remote forwarding endpoint, confirm against callers
	Protocol             string            // not read in this file; meaning to be confirmed against callers
	Type                 string            // forwarding type; Parse accepts only "local" or "remote"
	Labels               []string          // "key=value" strings that Parse copies into Condition
	Condition            map[string]string // host lookup filter rebuilt by Parse and passed to the DB host query
}

// conditionLabelPattern matches a "key=value" label: a non-empty key that
// contains no "=", the first "=", and a non-empty value that may itself
// contain "=". It is compiled once instead of once per label.
var conditionLabelPattern = regexp.MustCompile(`(?s)^([^=]+)=(.+)$`)

// labelToCondition copies every well-formed "key=value" entry of c.Labels into
// c.Condition. Labels that do not have that shape (no "=", empty key or empty
// value) are skipped.
//
// A label is split at its first "=" only, so a value that contains "=" is
// stored whole rather than truncated. c.Condition is allocated on first use
// when it is still nil, so the method cannot panic on a zero Config.
func (c *Config) labelToCondition() {
	for _, label := range c.Labels {
		kv := conditionLabelPattern.FindStringSubmatch(label)
		if kv == nil {
			continue
		}
		if c.Condition == nil {
			c.Condition = make(map[string]string)
		}
		// kv[1] is the key, kv[2] everything after the first "=".
		c.Condition[kv[1]] = kv[2]
	}
}

// Parse fills in every connection field that was left unset and loads the
// private key files.
//
// It rebuilds Condition from Host and Labels, looks the host up through
// getHostFromDB, and then resolves each target field in priority order: the
// value already on the Config, the value from the database record, and (for
// User, Port and Key) a built-in default of $USER, 22 and $HOME/.ssh/id_rsa.
// When a bastion host is configured, or the database lookup returned one, the
// bastion fields are resolved the same way.
//
// It returns an error when a key file cannot be read or when Type is neither
// "local" nor "remote".
func (c *Config) Parse() error {
	c.Condition = map[string]string{"connect_ip": c.Host}
	c.labelToCondition()

	// The lookup error is not inspected; when no record is returned the
	// defaults below apply.
	host, bastion, _ := c.getHostFromDB()

	// Login user for the target host.
	switch {
	case c.User != "":
		// Already set on the Config; keep it.
	case host != nil && host.SSHUser != "":
		c.User = host.SSHUser
	default:
		c.User = os.Getenv("USER")
	}

	switch {
	case c.Port != 0:
		// Already set on the Config; keep it.
	case host != nil && host.SSHPort > 0:
		c.Port = host.SSHPort
	default:
		c.Port = 22
	}

	switch {
	case c.Key != "":
		// Already set on the Config; keep it.
	case host != nil && host.SSHKeyFile != "":
		c.Key = replaceHomeDir(host.SSHKeyFile)
	default:
		c.Key = path.Join(os.Getenv("HOME"), ".ssh/id_rsa")
	}

	// Secrets have no built-in default; they can only come from the record.
	if host != nil {
		if c.Password == "" {
			c.Password = host.SSHPassword
		}
		if c.KeyPassphrase == "" {
			c.KeyPassphrase = host.SSHKeyPassphrase
		}
	}

	keyText, err := readSSHKey(c.Key)
	if err != nil {
		return err
	}
	c.KeyContent = keyText

	// Validate the forwarding type.
	switch c.Type {
	case "local", "remote":
	default:
		return errors.New("proxy type is one of [local, remote]")
	}

	// Resolve the bastion (jump host).
	switch {
	case c.BastionHost != "":
		// A bastion set on the Config wins; its own record, when one is
		// found, replaces the one attached to the target host.
		lookup := map[string]string{"connect_ip": c.BastionHost}
		if found, _ := model.QueryBastionByCondition(lookup); found != nil {
			bastion = found
		}
	case bastion != nil:
		c.BastionHost = bastion.ConnectIP
	default:
		// No bastion configured and none in the database: nothing left to do.
		return nil
	}

	switch {
	case c.BastionPort != 0:
		// Already set on the Config; keep it.
	case bastion != nil && bastion.SSHPort > 0:
		c.BastionPort = bastion.SSHPort
	default:
		c.BastionPort = 22
	}

	switch {
	case c.BastionUser != "":
		// Already set on the Config; keep it.
	case bastion != nil && bastion.SSHUser != "":
		c.BastionUser = bastion.SSHUser
	default:
		c.BastionUser = os.Getenv("USER")
	}

	switch {
	case c.BastionKey != "":
		// Already set on the Config; keep it.
	case bastion != nil && bastion.SSHKeyFile != "":
		c.BastionKey = replaceHomeDir(bastion.SSHKeyFile)
	default:
		c.BastionKey = path.Join(os.Getenv("HOME"), ".ssh/id_rsa")
	}

	if bastion != nil {
		if c.BastionPassword == "" {
			c.BastionPassword = bastion.SSHPassword
		}
		if c.BastionKeyPassphrase == "" {
			c.BastionKeyPassphrase = bastion.SSHKeyPassphrase
		}
	}

	bastionKeyText, err := readSSHKey(c.BastionKey)
	if err != nil {
		return err
	}
	c.BastionKeyContent = bastionKeyText

	return nil
}

// getHostFromDB looks up the target host matching c.Condition, together with
// the bastion that host references.
//
// All results are nil when config.GlobalConfig.UsedDB is false or when the
// host query returns no host. When the host has an empty BastionIP only the
// host is returned. If either query fails, both records are returned as nil
// along with the query error.
func (c *Config) getHostFromDB() (*model.Host, *model.Bastion, error) {
	if !config.GlobalConfig.UsedDB {
		return nil, nil, nil
	}

	target, err := model.QueryHostByCondition(c.Condition)
	switch {
	case err != nil, target == nil:
		return nil, nil, err
	case target.BastionIP == "":
		return target, nil, nil
	}

	jump, err := model.QueryBastionByCondition(map[string]string{"connect_ip": target.BastionIP})
	if err != nil {
		return nil, nil, err
	}
	return target, jump, nil
}

// readSSHKey returns the whole content of the file at path file as a string.
// When the file cannot be read it returns the empty string and the error
// reported by ioutil.ReadFile.
func readSSHKey(file string) (string, error) {
	var key string
	raw, err := ioutil.ReadFile(file)
	if err == nil {
		key = string(raw)
	}
	return key, err
}

// replaceHomeDir expands a leading "~" in file to the value of $HOME and
// returns the resulting path.
//
// Only a bare "~" or a "~/" prefix is expanded. Every other path is returned
// unchanged, including the "~name/..." form: splicing $HOME in front of the
// name would produce a path such as "/home/mename/..." that was never meant.
func replaceHomeDir(file string) string {
	if file == "~" || strings.HasPrefix(file, "~/") {
		// file[1:] is "" or starts with "/", so plain concatenation is enough.
		return os.Getenv("HOME") + file[1:]
	}
	return file
}
